from builtins import str
from builtins import zip
from builtins import range
import sys
sys.path.insert(1,"../../")
import h2o
from tests import pyunit_utils
from h2o.estimators.gbm import H2OGradientBoostingEstimator


def mycomp(l, r, tol=5e-8):
    """Assert that two tables agree row by row on their numeric cells.

    Only ``int``/``float`` cells take part in the comparison; every other
    cell (for example a string label) is dropped from both rows first.
    The paired values and the largest absolute difference of each row are
    printed to help diagnose failures.

    :param l: first table, a sequence of rows (each row a sequence of cells).
    :param r: second table, same layout as ``l``.
    :param tol: absolute tolerance; paired values must differ by strictly
        less than this amount. Defaults to ``5e-8``.
    :raises AssertionError: if the tables have a different number of rows,
        if a pair of rows has a different number of numeric cells, or if any
        paired values differ by ``tol`` or more.
    """
    assert len(l) == len(r), \
        "row count mismatch: " + str(len(l)) + " != " + str(len(r))
    for i, (l_row, r_row) in enumerate(zip(l, r)):
        l_i = [num for num in l_row if isinstance(num, (int, float))]
        r_i = [num for num in r_row if isinstance(num, (int, float))]
        # zip() stops at the shorter input, so without this check a missing
        # or extra numeric cell would silently escape the comparison.
        assert len(l_i) == len(r_i), (
            str(i) + ": numeric cell count mismatch: "
            + str(len(l_i)) + " != " + str(len(r_i)))
        zz = list(zip(l_i, r_i))
        print(zz)
        diffs = [abs(li - ri) for li, ri in zz]
        # A row without numeric cells has nothing to compare; max() of an
        # empty list would raise ValueError, so report a zero difference.
        diff = max(diffs) if diffs else 0.0
        print('diff', diff)
        z = [d < tol for d in diffs]

        assert all(z), str(i) + ":" + str(z)

def pubdev_2118():
    """Compare GBM gains/lift tables on the prostate data with recorded values.

    A default GBM is trained with the prostate frame as both training and
    validation data; its training, validation and model-performance
    gains/lift tables are each checked against ``expected_plain``. A second
    GBM with ``nfolds=3, seed=1234`` is then trained the same way; its
    cross-validation table is checked against ``expected_xval`` and its
    model-performance table against ``expected_plain`` again.
    """
    # Recorded table shared by the training / validation / scoring checks.
    expected_plain = [
        (u'', 1, 0.010526315789473684, 0.9656726269867583, 2.4836601307189543, 2.4836601307189543, 1.0, 0.9710687339136742, 1.0, 0.9710687339136742, 0.026143790849673203, 0.026143790849673203, 148.36601307189542, 148.36601307189542),
        (u'', 2, 0.021052631578947368, 0.9589343603635085, 2.4836601307189543, 2.4836601307189543, 1.0, 0.9628960864202921, 1.0, 0.9669824101669832, 0.026143790849673203, 0.05228758169934641, 148.36601307189542, 148.36601307189542),
        (u'', 3, 0.031578947368421054, 0.9507825529565218, 2.4836601307189543, 2.4836601307189543, 1.0, 0.95405241347286, 1.0, 0.9626724112689421, 0.026143790849673203, 0.0784313725490196, 148.36601307189542, 148.36601307189542),
        (u'', 4, 0.042105263157894736, 0.942267273985621, 2.4836601307189543, 2.4836601307189543, 1.0, 0.9471095592343903, 1.0, 0.9587816982603041, 0.026143790849673203, 0.10457516339869281, 148.36601307189542, 148.36601307189542),
        (u'', 5, 0.05, 0.930122612318471, 2.4836601307189543, 2.4836601307189543, 1.0, 0.9377545194004764, 1.0, 0.9554616173876999, 0.0196078431372549, 0.12418300653594772, 148.36601307189542, 148.36601307189542),
        (u'', 6, 0.1, 0.9044146521806046, 2.4836601307189543, 2.4836601307189543, 1.0, 0.9193294296626036, 1.0, 0.9373955235251517, 0.12418300653594772, 0.24836601307189543, 148.36601307189542, 148.36601307189542),
        (u'', 7, 0.15, 0.8446853133779267, 2.4836601307189543, 2.4836601307189543, 1.0, 0.8751891162377066, 1.0, 0.9166600544293366, 0.12418300653594772, 0.37254901960784315, 148.36601307189542, 148.36601307189542),
        (u'', 8, 0.2, 0.7961432054656624, 2.4836601307189543, 2.4836601307189543, 1.0, 0.822552825530283, 1.0, 0.8931332472045731, 0.12418300653594772, 0.49673202614379086, 148.36601307189542, 148.36601307189542),
        (u'', 9, 0.3, 0.6723258400857857, 2.4836601307189543, 2.4836601307189543, 1.0, 0.7354960586993771, 1.0, 0.8405875177028411, 0.24836601307189543, 0.7450980392156863, 148.36601307189542, 148.36601307189542),
        (u'', 10, 0.4, 0.4587689342692059, 1.6993464052287583, 2.287581699346405, 0.6842105263157895, 0.5612327297355235, 0.9210526315789473, 0.7707488207110117, 0.16993464052287582, 0.9150326797385621, 69.93464052287584, 128.7581699346405),
        (u'', 11, 0.5, 0.2941654313583412, 0.7843137254901961, 1.9869281045751637, 0.3157894736842105, 0.36138604514289685, 0.8, 0.6888762655973888, 0.0784313725490196, 0.9934640522875817, -21.568627450980394, 98.69281045751637),
        (u'', 12, 0.6, 0.1936958406847592, 0.06535947712418301, 1.6666666666666667, 0.02631578947368421, 0.23661298765425717, 0.6710526315789473, 0.6134990526068669, 0.006535947712418301, 1.0, -93.4640522875817, 66.66666666666667),
        (u'', 13, 0.7, 0.11690112068515252, 0.0, 1.4285714285714286, 0.0, 0.15563383142391402, 0.575187969924812, 0.5480897352950165, 0.0, 1.0, -100.0, 42.85714285714286),
        (u'', 14, 0.8, 0.08004747799144096, 0.0, 1.25, 0.0, 0.09663009317659732, 0.5032894736842105, 0.49165728003021403, 0.0, 1.0, -100.0, 25.0),
        (u'', 15, 0.9, 0.04735533377423584, 0.0, 1.1111111111111112, 0.0, 0.06444853566528014, 0.4473684210526316, 0.44418964176744363, 0.0, 1.0, -100.0, 11.111111111111116),
        (u'', 16, 1.0, 0.009748409930340669, 0.0, 1.0, 0.0, 0.03071707428603298, 0.4026315789473684, 0.4028423850193026, 0.0, 1.0, -100.0, 0.0),
    ]
    # Recorded cross-validation table for the nfolds=3, seed=1234 model.
    expected_xval = [
        (u'', 1, 0.010526315789473684, 0.9720357890880759, 1.8627450980392157, 1.8627450980392157, 0.75, 0.9760072931854437, 0.75, 0.9760072931854437, 0.0196078431372549, 0.0196078431372549, 86.27450980392157, 86.27450980392157),
        (u'', 2, 0.021052631578947368, 0.9621849870843923, 1.8627450980392157, 1.8627450980392157, 0.75, 0.9669374439725144, 0.75, 0.971472368578979, 0.0196078431372549, 0.0392156862745098, 86.27450980392157, 86.27450980392157),
        (u'', 3, 0.031578947368421054, 0.9490492319323595, 1.8627450980392157, 1.8627450980392157, 0.75, 0.9568784717485775, 0.75, 0.9666077363021786, 0.0196078431372549, 0.058823529411764705, 86.27450980392157, 86.27450980392157),
        (u'', 4, 0.042105263157894736, 0.9331874956273033, 1.8627450980392157, 1.8627450980392157, 0.75, 0.9412334873687302, 0.75, 0.9602641740688165, 0.0196078431372549, 0.0784313725490196, 86.27450980392157, 86.27450980392157),
        (u'', 5, 0.05, 0.9319212918270888, 2.4836601307189543, 1.9607843137254903, 1.0, 0.9326627755163045, 0.7894736842105263, 0.9559060585078936, 0.0196078431372549, 0.09803921568627451, 148.36601307189542, 96.07843137254903),
        (u'', 6, 0.1, 0.8704014317587268, 2.2222222222222223, 2.0915032679738563, 0.8947368421052632, 0.9064122080712612, 0.8421052631578947, 0.9311591332895773, 0.1111111111111111, 0.20915032679738563, 122.22222222222223, 109.15032679738563),
        (u'', 7, 0.15, 0.7994327310905681, 1.3071895424836601, 1.8300653594771241, 0.5263157894736842, 0.8237496530969899, 0.7368421052631579, 0.8953559732253816, 0.06535947712418301, 0.27450980392156865, 30.718954248366014, 83.00653594771241),
        (u'', 8, 0.2, 0.7409640897307539, 1.8300653594771241, 1.8300653594771241, 0.7368421052631579, 0.7797762313520376, 0.7368421052631579, 0.8664610377570455, 0.0915032679738562, 0.3660130718954248, 83.00653594771241, 83.00653594771241),
        (u'', 9, 0.3, 0.5970714663081699, 1.5686274509803921, 1.7429193899782136, 0.631578947368421, 0.6593952334847722, 0.7017543859649122, 0.7974391029996212, 0.1568627450980392, 0.5228758169934641, 56.86274509803921, 74.29193899782136),
        (u'', 10, 0.4, 0.430473955067175, 1.4379084967320264, 1.6666666666666667, 0.5789473684210527, 0.5094600228139071, 0.6710526315789473, 0.7254443329531925, 0.1437908496732026, 0.6666666666666666, 43.790849673202636, 66.66666666666667),
        (u'', 11, 0.5, 0.3161245366286738, 0.9803921568627452, 1.5294117647058825, 0.39473684210526316, 0.3786448086411368, 0.6157894736842106, 0.6560844280907814, 0.09803921568627451, 0.7647058823529411, -1.9607843137254832, 52.941176470588246),
        (u'', 12, 0.6, 0.23218508200249308, 0.9150326797385621, 1.4270152505446625, 0.3684210526315789, 0.2812411359528033, 0.5745614035087719, 0.5936105460677851, 0.0915032679738562, 0.8562091503267973, -8.496732026143794, 42.70152505446625),
        (u'', 13, 0.7, 0.14659241228908762, 0.5228758169934641, 1.2978524743230628, 0.21052631578947367, 0.18397544100232177, 0.5225563909774437, 0.5350912453441474, 0.05228758169934641, 0.9084967320261438, -47.71241830065359, 29.785247432306285),
        (u'', 14, 0.8, 0.09234093116404647, 0.45751633986928103, 1.1928104575163399, 0.18421052631578946, 0.12081167570481807, 0.48026315789473684, 0.4833062991392313, 0.0457516339869281, 0.954248366013072, -54.2483660130719, 19.281045751633986),
        (u'', 15, 0.9, 0.04771122966691487, 0.13071895424836602, 1.074800290486565, 0.05263157894736842, 0.06846481525965206, 0.4327485380116959, 0.43721280093038917, 0.013071895424836602, 0.9673202614379085, -86.9281045751634, 7.4800290486565),
        (u'', 16, 1.0, 0.009938670599098145, 0.32679738562091504, 1.0, 0.13157894736842105, 0.027346092704593806, 0.4026315789473684, 0.39622613010780966, 0.032679738562091505, 1.0, -67.3202614379085, 0.0),
    ]

    prostate = h2o.import_file(pyunit_utils.locate("smalldata/prostate/prostate.csv"))
    # The response column is converted to a factor before training.
    prostate["CAPSULE"] = prostate["CAPSULE"].asfactor()

    # --- Default GBM: training, validation and scoring tables ---
    gbm = H2OGradientBoostingEstimator()
    gbm.train(x=prostate.names, y="CAPSULE",
              training_frame=prostate, validation_frame=prostate)

    train_cells = gbm.gains_lift().cell_values
    print(train_cells)
    mycomp(expected_plain, train_cells)

    mycomp(expected_plain, gbm.gains_lift(valid=True).cell_values)

    perf = gbm.model_performance(prostate)
    mycomp(expected_plain, perf.gains_lift().cell_values)

    # --- Cross-validated GBM: xval table, then scoring table ---
    gbm = H2OGradientBoostingEstimator(nfolds=3, seed=1234)
    gbm.train(x=prostate.names, y="CAPSULE",
              training_frame=prostate, validation_frame=prostate)

    xval_cells = gbm.gains_lift(xval=True).cell_values
    print(xval_cells)
    mycomp(expected_xval, xval_cells)

    perf = gbm.model_performance(prostate)
    mycomp(expected_plain, perf.gains_lift().cell_values)


# Imported by another module rather than executed as a script: run the test
# body immediately.
if __name__ != "__main__":
    pubdev_2118()
else:
    # Direct execution hands the test function to the pyunit helper.
    pyunit_utils.standalone_test(pubdev_2118)
